{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
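   },
   "source": [
    "# Train and Deploy an XGBoost Model with the PAI Python SDK\n",
    "\n",
    "\n",
    "[XGBoost](https://xgboost.readthedocs.io/) is an efficient implementation of the decision-tree-based [gradient boosting](https://en.wikipedia.org/wiki/Gradient_boosting) algorithm. It is a popular machine learning library that handles large datasets and includes many training performance optimizations.\n",
    "\n",
    "In this tutorial, we use the PAI Python SDK to train an XGBoost model on PAI, deploy the resulting model as an online inference service, and send test requests to it."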
   ]
  },
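  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Billing\n",
    "\n",
    "This example uses the following cloud products, which will incur charges:\n",
    "\n",
    "- PAI-DLC: runs the training job. For billing details, see [PAI-DLC billing](https://help.aliyun.com/zh/pai/product-overview/billing-of-dlc)\n",
    "- PAI-EAS: hosts the inference service. For billing details, see [PAI-EAS billing](https://help.aliyun.com/zh/pai/product-overview/billing-of-eas)\n",
    "- OSS: stores the model, TensorBoard logs, and other outputs of the training job. For billing details, see [OSS billing overview](https://help.aliyun.com/zh/oss/product-overview/billing-overview)\n",
    "\n",
    "\n",
    "> By joining the cloud product free trial and using the **designated instance types** to submit training jobs or deploy inference services, you can try PAI for free. For details, see [PAI free trial](https://help.aliyun.com/zh/pai/product-overview/free-quota-for-new-users).\n",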
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Step 1: Preparation\n",
    "\n",
    "First, install the PAI Python SDK, which is required to run this example.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -m pip install --upgrade pai"
   ]
  },
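  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "The SDK needs an AccessKey to access Alibaba Cloud services, as well as the workspace and OSS bucket to use. After installing the PAI SDK, run the following command in a **command-line terminal** and follow the prompts to configure the credentials, workspace, and other settings.\n",
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# Run the following command in a command-line terminal.\n",
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```\n",
    "\n",
    "The following code verifies the current configuration.\n"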
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Verify the installation and configuration\n",
    "import pai\n",
    "from pai.session import get_default_session\n",
    "\n",
    "print(pai.__version__)\n",
    "\n",
    "sess = get_default_session()\n",
    "\n",
    "assert sess.workspace_name is not None\n",
    "print(sess.workspace_name)"
   ]
  },
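  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Prepare the dataset\n",
    "\n",
    "We use the [Breast Cancer dataset](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)) to train and test the XGBoost model. The dataset is prepared as follows:\n",
    "\n",
    "1. Load and split the Breast Cancer dataset with `scikit-learn`, and save it locally in `csv` format.\n",
    "\n",
    "2. Upload the local dataset to the OSS bucket to obtain its OSS URI, which the training job running in the cloud uses as input."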
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Load and split the dataset with scikit-learn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train data local path: ./train_data/train.csv\n",
      "test data local path: ./test_data/train.csv\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "# Install scikit-learn, used to load and split the dataset\n",
    "!{sys.executable} -m pip install --quiet scikit-learn\n",
    "\n",
    "# Create the dataset directories\n",
    "!mkdir -p ./train_data\n",
    "!mkdir -p ./test_data\n",
    "\n",
    "from sklearn import datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "df = datasets.load_breast_cancer(as_frame=True)\n",
    "\n",
    "train, test = train_test_split(df.frame, test_size=0.3)\n",
    "\n",
    "train_data_local = \"./train_data/train.csv\"\n",
    "test_data_local = \"./test_data/train.csv\"\n",
    "\n",
    "train.to_csv(train_data_local, index=False)\n",
    "test.to_csv(test_data_local, index=False)\n",
    "\n",
    "print(f\"train data local path: {train_data_local}\")\n",
    "print(f\"test data local path: {test_data_local}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Upload the dataset to the OSS bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the dataset to the OSS bucket\n",
    "from pai.common.oss_utils import upload\n",
    "\n",
    "\n",
    "# Upload the training data to OSS\n",
    "train_data = upload(\n",
    "    train_data_local,\n",
    "    \"pai/xgboost-example/train_data/\",\n",
    "    sess.oss_bucket,\n",
    ")\n",
    "\n",
    "\n",
    "test_data = upload(\n",
    "    test_data_local,\n",
    "    \"pai/xgboost-example/test_data/\",\n",
    "    sess.oss_bucket,\n",
    ")\n",
    "\n",
    "print(f\"train data: {train_data}\")\n",
    "print(f\"test data: {test_data}\")"
   ]
  },
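  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Submit the training job\n",
    "\n",
    "With the `Estimator` provided by the PAI Python SDK, you can submit a training script to PAI, run it as a training job, and obtain the output model. The main steps are:\n",
    "\n",
    "1. Write the training script\n",
    "\n",
    "The training script contains the model code. It must follow the conventions of PAI training jobs to obtain the job hyperparameters, read the input data, and save the model to the designated output directory.\n",
    "\n",
    "2. Build an `Estimator` object\n",
    "\n",
    "Through the `Estimator` API, you configure the script, image, hyperparameters, and instance type used by the training job.\n",
    "The `Estimator` uploads the local script to the OSS bucket, and the script is then loaded into the training job.\n",
    "\n",
    "3. Call the `Estimator.fit` API to submit the job\n",
    "\n",
    "`.fit` submits a training job. By default, it blocks until the job stops. After the job finishes, `estimator.model_data()` returns the OSS URI of the output model.\n",
    "\n",
    "For a more complete introduction, see [Documentation: Submit a training job](https://pai-sdk.oss-cn-shanghai.aliyuncs.com/pai/doc/latest/user-guide/estimator.html)"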
   ]
  },
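  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the scikit-learn API provided by XGBoost, we build the following XGBoost training script:\n",
    "\n",
    "- By default, the training job receives two input channels, `train` and `test`. The training script reads the data from `/ml/input/data/{channel_name}`.\n",
    "\n",
    "- When training finishes, the script must write the model to the `/ml/output/model` directory."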
   ]
  },
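  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The conventions above can be sketched as a few small helpers. This is only an illustration under stated assumptions: `channel_dir`, `first_csv`, and `model_output_dir` are hypothetical names rather than SDK functions, and the `PAI_OUTPUT_MODEL` lookup assumes the job container exports that variable, as the job log later in this notebook suggests.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "# Paths described above for PAI training jobs.\n",
    "INPUT_DATA_DIR = '/ml/input/data'\n",
    "DEFAULT_MODEL_DIR = '/ml/output/model'\n",
    "\n",
    "\n",
    "def channel_dir(channel_name, base_dir=INPUT_DATA_DIR):\n",
    "    # Directory where the data of an input channel is expected (hypothetical helper).\n",
    "    return os.path.join(base_dir, channel_name)\n",
    "\n",
    "\n",
    "def first_csv(path):\n",
    "    # Return the first CSV file in a channel directory, or None if there is none.\n",
    "    if not os.path.isdir(path):\n",
    "        return None\n",
    "    names = sorted(n for n in os.listdir(path) if n.endswith('.csv'))\n",
    "    return os.path.join(path, names[0]) if names else None\n",
    "\n",
    "\n",
    "def model_output_dir():\n",
    "    # Prefer PAI_OUTPUT_MODEL when it is set (assumption), else use the default path.\n",
    "    return os.environ.get('PAI_OUTPUT_MODEL', DEFAULT_MODEL_DIR)\n",
    "\n",
    "\n",
    "print(channel_dir('train'))\n",
    "print(first_csv(channel_dir('train')))\n",
    "print(model_output_dir())\n",
    "```\n",
    "\n",
    "Resolving the paths in one place keeps the training script easy to try outside PAI, for example by pointing `base_dir` at a local folder."
   ]
  },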
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "!mkdir -p xgb_src/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting xgb_src/train.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile xgb_src/train.py\n",
    "\n",
    "\n",
    "import argparse\n",
    "import logging\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "logging.basicConfig(format=\"%(levelname)s:%(message)s\", level=logging.INFO)\n",
    "\n",
    "TRAINING_BASE_DIR = \"/ml/\"\n",
    "TRAINING_OUTPUT_MODEL_DIR = os.path.join(TRAINING_BASE_DIR, \"output/model/\")\n",
    "\n",
    "\n",
    "def load_dataset(channel_name):\n",
    "    path = os.path.join(TRAINING_BASE_DIR, \"input/data\", channel_name)\n",
    "    if not os.path.exists(path):\n",
    "        return None, None\n",
    "\n",
    "    # use first file in the channel dir.\n",
    "    file_name = next(\n",
    "        iter([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]),\n",
    "        None,\n",
    "    )\n",
    "    if not file_name:\n",
    "        logging.warning(f\"Not found input file in channel path: {path}\")\n",
    "        return None, None\n",
    "\n",
    "    file_path = os.path.join(path, file_name)\n",
    "    df = pd.read_csv(\n",
    "        filepath_or_buffer=file_path,\n",
    "        sep=\",\",\n",
    "    )\n",
    "\n",
    "    train_y = df[\"target\"]\n",
    "    train_x = df.drop([\"target\"], axis=1)\n",
    "    return train_x, train_y\n",
    "\n",
    "\n",
    "def main():\n",
    "    parser = argparse.ArgumentParser(description=\"XGBoost train arguments\")\n",
    "    # 用户指定的任务参数\n",
    "    parser.add_argument(\n",
    "        \"--n_estimators\", type=int, default=500, help=\"The number of base model.\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--objective\", type=str, help=\"Objective function used by XGBoost\"\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--max_depth\", type=int, default=3, help=\"The maximum depth of the tree.\"\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--eta\",\n",
    "        type=float,\n",
    "        default=0.2,\n",
    "        help=\"Step size shrinkage used in update to prevents overfitting.\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--eval_metric\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        help=\"Evaluation metrics for validation data\"\n",
    "    )\n",
    "\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    # 加载数据集\n",
    "    train_x, train_y = load_dataset(\"train\")\n",
    "    print(\"Train dataset: train_shape={}\".format(train_x.shape))\n",
    "    test_x, test_y = load_dataset(\"test\")\n",
    "    if test_x is None or test_y is None:\n",
    "        print(\"Test dataset not found\")\n",
    "        eval_set = [(train_x, train_y)]\n",
    "    else:\n",
    "        eval_set = [(train_x, train_y), (test_x, test_y)]\n",
    "\n",
    "    clf = XGBClassifier(\n",
    "        max_depth=args.max_depth,\n",
    "        eta=args.eta,\n",
    "        n_estimators=args.n_estimators,\n",
    "        objective=args.objective,\n",
    "    )\n",
    "    clf.fit(train_x, train_y, eval_set=eval_set, eval_metric=args.eval_metric)\n",
    "\n",
    "    model_path = os.environ.get(\"PAI_OUTPUT_MODEL\")\n",
    "    os.makedirs(model_path, exist_ok=True)\n",
    "    clf.save_model(os.path.join(model_path, \"model.json\"))\n",
    "    print(f\"Save model succeed: model_path={model_path}/model.json\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
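  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Submit the training job with Estimator\n",
    "\n",
    "With `Estimator`, the training script built above (`xgb_src/train.py`) is uploaded to OSS, and `fit` submits an XGBoost training job that runs in the cloud. The `inputs` passed to `fit` are the training and test data uploaded earlier. They are mounted into the job container at `/ml/input/data/{channel_name}/`, where the training script reads them.\n",
    "\n",
    "After submission, the SDK prints the URL of the job detail page and streams the job logs until the job exits (succeeded, failed, or stopped). You can open the job URL to view the job details, logs, model metrics, machine resource utilization, and other information."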
   ]
  },
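  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The job log in the following cell shows the hyperparameters reaching the script as `--max_depth 5 --eval_metric auc --n_estimator 100 --criterion gini`. The sketch below uses only the standard library to illustrate how a parser like the one in `train.py` treats such arguments: `argparse` accepts an unambiguous prefix such as `--n_estimator` for `--n_estimators`, and `parse_known_args` returns unrecognized arguments such as `--criterion` in a separate list instead of raising an error. The `to_cli_args` helper is a hypothetical stand-in for how a hyperparameter dict could be flattened; it is not the SDK implementation.\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "\n",
    "def to_cli_args(hyperparameters):\n",
    "    # Hypothetical helper: flatten {'name': value} into ['--name', 'value', ...].\n",
    "    args = []\n",
    "    for name, value in hyperparameters.items():\n",
    "        args += ['--' + name, str(value)]\n",
    "    return args\n",
    "\n",
    "\n",
    "# A reduced version of the parser defined in train.py.\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--n_estimators', type=int, default=500)\n",
    "parser.add_argument('--max_depth', type=int, default=3)\n",
    "parser.add_argument('--eval_metric', type=str, default=None)\n",
    "\n",
    "cli_args = to_cli_args(\n",
    "    {'max_depth': 5, 'eval_metric': 'auc', 'n_estimator': 100, 'criterion': 'gini'}\n",
    ")\n",
    "known, unknown = parser.parse_known_args(cli_args)\n",
    "\n",
    "print(known.n_estimators)  # 100, matched through the '--n_estimator' prefix\n",
    "print(unknown)  # ['--criterion', 'gini']\n",
    "```\n",
    "\n",
    "Relying on prefix matching is fragile, so spelling hyperparameter names exactly as the script declares them is the safer choice."
   ]
  },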
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "registry.cn-beijing.aliyuncs.com/pai-dlc/xgboost-training:1.6.0-cpu-py36-ubuntu18.04\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Uploading file: /var/folders/jn/9tcbd4h56z5g3wbbd5sms38m0000gp/T/tmpwwk7er5t/source.tar.gz: 100%|██████████| 1.39k/1.39k [00:00<00:00, 8.84kB/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "View the job detail by accessing the console URI: https://pai.console.aliyun.com/?regionId=cn-beijing&workspaceId=90914#/training/jobs/train10qiryeueit\n",
      "TrainingJob launch starting\n",
      "KUBERNETES_PORT=tcp://10.192.0.1:443\n",
      "KUBERNETES_SERVICE_PORT=6443\n",
      "PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com\n",
      "CMAKE_VERSION=3.14\n",
      "SCRAPE_PROMETHEUS_METRICS=yes\n",
      "MASTER_ADDR=train10qiryeueit-master-0\n",
      "HOSTNAME=train10qiryeueit-master-0\n",
      "MASTER_PORT=23456\n",
      "HOME=/root\n",
      "PAI_USER_ARGS=--max_depth 5 --eval_metric auc --n_estimator 100 --criterion gini\n",
      "PYTHONUNBUFFERED=0\n",
      "PAI_HPS_MAX_DEPTH=5\n",
      "NPROC_PER_NODE=0\n",
      "PAI_CONFIG_DIR=/ml/input/config/\n",
      "PAI_OUTPUT_CHECKPOINTS=/ml/output/checkpoints/\n",
      "WORLD_SIZE=1\n",
      "REGION_ID=cn-beijing\n",
      "RANK=0\n",
      "PAI_INPUT_TRAIN=/ml/input/data/train/train.csv\n",
      "TENANT_API_SERVER_URL=https://10.224.139.70:6443\n",
      "PAI_TRAINING_JOB_ID=train10qiryeueit\n",
      "PAI_OUTPUT_TENSORBOARD=/ml/output/tensorboard/\n",
      "KUBERNETES_PORT_443_TCP_ADDR=10.192.0.1\n",
      "PAI_OUTPUT_MODEL=/ml/output/model/\n",
      "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin\n",
      "PIP_INDEX_URL=https://mirrors.cloud.aliyuncs.com/pypi/simple\n",
      "KUBERNETES_PORT_443_TCP_PORT=443\n",
      "KUBERNETES_PORT_443_TCP_PROTO=tcp\n",
      "PAI_HPS_N_ESTIMATOR=100\n",
      "PAI_TRAINING_USE_ECI=true\n",
      "KUBERNETES_CONTAINER_RESOURCE_GPU=0\n",
      "PAI_INPUT_TEST=/ml/input/data/test/train.csv\n",
      "KUBERNETES_PORT_443_TCP=tcp://10.192.0.1:443\n",
      "KUBERNETES_SERVICE_PORT_HTTPS=443\n",
      "KUBERNETES_SERVICE_HOST=10.224.139.70\n",
      "PWD=/root\n",
      "PAI_HPS={\"criterion\":\"gini\",\"eval_metric\":\"auc\",\"max_depth\":\"5\",\"n_estimator\":\"100\"}\n",
      "PAI_HPS_CRITERION=gini\n",
      "PAI_HPS_EVAL_METRIC=auc\n",
      "PAI_WORKING_DIR=/ml/usercode/\n",
      "PAI_ODPS_CREDENTIAL=/ml/input/credential/odps.json\n",
      "Change to Working Directory, /ml/usercode/\n",
      "Installing dependencies from /ml/input/config//requirements.txt\n",
      "/usr/lib/python3/dist-packages/secretstorage/dhcrypto.py:15: CryptographyDeprecationWarning: int_from_bytes is deprecated, use int.from_bytes instead\n",
      "  from cryptography.utils import int_from_bytes\n",
      "/usr/lib/python3/dist-packages/secretstorage/util.py:19: CryptographyDeprecationWarning: int_from_bytes is deprecated, use int.from_bytes instead\n",
      "  from cryptography.utils import int_from_bytes\n",
      "Looking in indexes: https://mirrors.cloud.aliyuncs.com/pypi/simple\n",
      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n",
      "User program launching\n",
      "-----------------------------------------------------------------\n",
      "Train dataset: train_shape=(398, 30)\n",
      "/usr/local/lib/python3.6/dist-packages/xgboost/sklearn.py:817: UserWarning: `eval_metric` in `fit` method is deprecated for better compatibility with scikit-learn, use `eval_metric` in constructor or`set_params` instead.\n",
      "  UserWarning,\n",
      "[0]\tvalidation_0-auc:0.99620\tvalidation_1-auc:0.94604\n",
      "[1]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.95106\n",
      "[2]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.95235\n",
      "[3]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.95393\n",
      "[4]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.95307\n",
      "[5]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97445\n",
      "[6]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97402\n",
      "[7]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97388\n",
      "[8]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97912\n",
      "[9]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97898\n",
      "[10]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97876\n",
      "[11]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97869\n",
      "[12]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97912\n",
      "[13]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97869\n",
      "[14]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97898\n",
      "[15]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97898\n",
      "[16]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97854\n",
      "[17]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97883\n",
      "[18]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97883\n",
      "[19]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97854\n",
      "[20]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97826\n",
      "[21]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97826\n",
      "[22]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97826\n",
      "[23]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97819\n",
      "[24]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97826\n",
      "[25]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97840\n",
      "[26]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97840\n",
      "[27]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97833\n",
      "[28]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97833\n",
      "[29]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97840\n",
      "[30]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97905\n",
      "[31]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97984\n",
      "[32]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.97984\n",
      "[33]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98213\n",
      "[34]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98213\n",
      "[35]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[36]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[37]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98192\n",
      "[38]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98278\n",
      "[39]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98263\n",
      "[40]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98278\n",
      "[41]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[42]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[43]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98292\n",
      "[44]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98292\n",
      "[45]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98292\n",
      "[46]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98292\n",
      "[47]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[48]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[49]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[50]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[51]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[52]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[53]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[54]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[55]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[56]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[57]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[58]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[59]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[60]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[61]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[62]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98220\n",
      "[63]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[64]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[65]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98220\n",
      "[66]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98220\n",
      "[67]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[68]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[69]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[70]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98263\n",
      "[71]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98263\n",
      "[72]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98263\n",
      "[73]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98263\n",
      "[74]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98263\n",
      "[75]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98220\n",
      "[76]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[77]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98263\n",
      "[78]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[79]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[80]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98249\n",
      "[81]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[82]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[83]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[84]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[85]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[86]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[87]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[88]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[89]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[90]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[91]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[92]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[93]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[94]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[95]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[96]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[97]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[98]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "[99]\tvalidation_0-auc:1.00000\tvalidation_1-auc:0.98235\n",
      "Save model succeed: model_path=/ml/output/model//model.json\n",
      "\n",
      "Training job (train10qiryeueit) succeeded, you can check the logs/metrics/output in  the console:\n",
      "https://pai.console.aliyun.com/?regionId=cn-beijing&workspaceId=90914#/training/jobs/train10qiryeueit\n",
      "oss://mlops-poc-beijing/pai/training_job/estimator_20240712_101220_265_j03jyg/model/\n"
     ]
    }
   ],
   "source": [
    "from pai.estimator import Estimator\n",
    "from pai.image import retrieve\n",
    "\n",
    "\n",
    "# Retrieve the XGBoost training image provided by PAI\n",
    "image_uri = retrieve(\"xgboost\", framework_version=\"latest\").image_uri\n",
    "print(image_uri)\n",
    "\n",
    "# 构建一个Estimator实例\n",
    "est = Estimator(\n",
    "    # 作业启动脚本\n",
    "    command=\"python train.py $PAI_USER_ARGS\",\n",
    "    # 作业脚本的本地文件夹路径，会被打包上传到OSS\n",
    "    source_dir=\"./xgb_src/\",\n",
    "    image_uri=image_uri,\n",
    "    # 作业超参: 会通过Command arguments的方式传递给到作业脚本\n",
    "    hyperparameters={\n",
    "        \"n_estimator\": 100,\n",
    "        \"criterion\": \"gini\",\n",
    "        \"max_depth\": 5,\n",
    "        \"eval_metric\": \"auc\",\n",
    "    },\n",
    "    # 作业使用的机器实例\n",
    "    instance_type=\"ecs.c6.large\",\n",
    ")\n",
    "\n",
    "# 使用上传到OSS的训练数据作为作业的数据\n",
    "est.fit(\n",
    "    inputs={\n",
    "        \"train\": train_data,  # train_data 将被挂载到`/ml/input/data/train`目录\n",
    "        \"test\": test_data,  # test_data 将被挂载到`/ml/input/data/test`目录\n",
    "    },\n",
    ")\n",
    "print(est.model_data())"
   ]
  },
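  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Deploy the model\n",
    "\n",
    "We deploy the model obtained from the training job above as an online service. The code below serves the model from a container that runs a FastAPI application; PAI also provides a [prebuilt XGBoost Processor](https://help.aliyun.com/document_detail/470490.html) for deployment. The main steps are:\n",
    "\n",
    "1. Build an `InferenceSpec`\n",
    "\n",
    "The `InferenceSpec` describes how the model is deployed as an online service, for example whether it is served from a container image or by a processor.\n",
    "\n",
    "2. Build a `Model` object\n",
    "\n",
    "A `Model` object can be deployed as a service directly, or registered to the PAI model registry through `.register`.\n",
    "\n",
    "3. Deploy the online service with `Model.deploy`\n",
    "\n",
    "Specify the service name and instance type to deploy a new online inference service.\n"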
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p xgb_infer/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing xgb_infer/serving.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile xgb_infer/serving.py\n",
    "\n",
    "\n",
    "import logging\n",
    "import os\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "import uvicorn\n",
    "import xgboost\n",
    "from fastapi import FastAPI\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "logging.basicConfig(\n",
    "    format=\"%(asctime)s: %(message)s\",\n",
    "    datefmt=\"%m/%d/%Y %I:%M:%S %p\",\n",
    "    level=logging.INFO,\n",
    ")\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "app = FastAPI()\n",
    "\n",
    "\n",
    "def load_model():\n",
    "    model_dir = os.environ.get(\"MODEL_MOUNT_PATH\", \"/eas/workspace/model\")\n",
    "    logger.info(\"model_dir: %s\", model_dir)\n",
    "    name = next((name for name in os.listdir(model_dir) if name.endswith(\"json\")), None)\n",
    "    logger.info(\"model dir files: %s\", os.listdir(model_dir))\n",
    "    if not name:\n",
    "        raise RuntimeError(\"Not found sklearn learn model under the model directory.\")\n",
    "\n",
    "    xgb_model = XGBClassifier()\n",
    "    xgb_model.load_model(os.path.join(model_dir, name))\n",
    "    return xgb_model\n",
    "\n",
    "\n",
    "model = load_model()\n",
    "\n",
    "\n",
    "@app.post(\"/\")\n",
    "def predict_v1(data: List):\n",
    "    global model\n",
    "    logger.info(\"API PredictV1 Invocation.\")\n",
    "    x = np.asarray(data)\n",
    "    y = model.predict(x)\n",
    "    return y.tolist()\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    logger.info(\"FastAPI server launching\")\n",
    "    logger.info(\"Environment Variables: %s\", os.environ)\n",
    "    logger.info(\"XGBoost Version: %s\", xgboost.__version__)\n",
    "    port = int(os.environ.get(\"LISTENING_PORT\", 8000))\n",
    "    uvicorn.run(app, host=\"0.0.0.0\", port=port)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.model import InferenceSpec, Model, container_serving_spec\n",
    "from pai.common.utils import random_str\n",
    "\n",
    "model_data = est.model_data()\n",
    "\n",
    "\n",
    "image_uri = retrieve(\"xgboost\", framework_version=\"latest\").image_uri\n",
    "\n",
    "inference_spec = container_serving_spec(\n",
    "    # 推理代码目录\n",
    "    source_dir=\"./xgb_infer\",\n",
    "    # 启动命令\n",
    "    command=\"python serving.py\",\n",
    "    # 推理镜像\n",
    "    image_uri=image_uri,\n",
    "    port=5000,\n",
    "    requirements=[\"uvicorn[standard]\", \"fastapi\"],\n",
    ")\n",
    "model = Model(\n",
    "    inference_spec=inference_spec,\n",
    "    model_data=model_data,\n",
    ")\n",
    "\n",
    "predictor = model.deploy(\n",
    "    service_name=\"example_xgb_{}\".format(random_str(6)),\n",
    "    instance_type=\"ecs.c6.xlarge\",\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Test the online service\n",
    "\n",
    "`Model.deploy` returns a `Predictor` object. Its `Predictor.predict` method sends an inference request to the deployed service and returns the prediction result."
   ]
  },
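  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To judge the responses, the predicted labels can be compared with the ground-truth `target` column of the test split. A minimal sketch, where `accuracy` is a hypothetical helper and the label lists are made-up placeholders rather than results from this notebook:\n",
    "\n",
    "```python\n",
    "def accuracy(y_true, y_pred):\n",
    "    # Fraction of predictions that match the ground-truth labels.\n",
    "    if len(y_true) != len(y_pred):\n",
    "        raise ValueError('y_true and y_pred must have the same length')\n",
    "    if not y_true:\n",
    "        raise ValueError('inputs must not be empty')\n",
    "    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)\n",
    "    return correct / len(y_true)\n",
    "\n",
    "\n",
    "# Placeholder labels for illustration only.\n",
    "y_true = [1, 0, 1, 1]\n",
    "y_pred = [1, 0, 0, 1]\n",
    "print(accuracy(y_true, y_pred))  # 0.75\n",
    "```\n",
    "\n",
    "In this notebook, `y_true` would be `test['target'].tolist()` and `y_pred` the list returned by `predictor.predict`."
   ]
  },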
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0]\n"
     ]
    }
   ],
   "source": [
    "# print(p.service_name)\n",
    "\n",
    "test_x = test.drop([\"target\"], axis=1)\n",
    "\n",
    "res = predictor.predict(test_x.to_numpy().tolist())\n",
    "\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "After testing, delete the service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_service()"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 1800
  },
  "kernelspec": {
   "display_name": "pai-dev-py36",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  },
  "vscode": {
   "interpreter": {
    "hash": "63703143536f433679c5464335316251eaa13807b3fcc3854dae32f2699871d6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
